import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized.dynamic as nnqd
import torch.nn.quantized as nnq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
# Short alias for the torch.ops.quantized namespace; its ops are used below as
# keys of the 'call_function' weight-extraction mapping.
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node

from .utils import (
    get_target_type_str,
    getattr_from_fqn,
    return_first_non_observer_node,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSingleResultType,
)

from typing import List, Optional, Dict, Callable

def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    """Return ``mod.weight`` after calling ``detach()`` on it."""
    weight = mod.weight
    return weight.detach()  # type: ignore[operator]

def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    """Return the detached ``weight`` of the first child, ``mod[0]``."""
    first_child = mod[0]  # type: ignore[index]
    return first_child.weight.detach()

def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    """Return element 0 of whatever ``mod._weight_bias()`` returns.

    Presumably this is the weight half of a (weight, bias) pair -- confirm
    against the module types this is registered for.
    """
    weight_and_bias = mod._weight_bias()  # type: ignore[operator]
    return weight_and_bias[0]

def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    """Collect the detached input-hidden and hidden-hidden weights of ``mod``.

    Walks ``mod._flat_weights_names`` and, for every name containing
    ``'weight_ih_l'`` or ``'weight_hh_l'``, returns the detached entry at the
    same position of ``mod._flat_weights``. The order follows
    ``_flat_weights_names``.
    """
    wanted_tags = ('weight_ih_l', 'weight_hh_l')
    return [
        mod._flat_weights[position].detach()  # type: ignore[index]
        for position, name in enumerate(mod._flat_weights_names)  # type: ignore[arg-type]
        if any(tag in name for tag in wanted_tags)
    ]

def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    """Pull two values out of each entry of ``mod._all_weight_values``.

    For every entry, reads ``entry.param.__getstate__()[0][4]`` and, for slots
    0 and 1 of that sequence (in that order), returns
    ``slot.__getstate__()[0][0]``. Presumably these are the input-hidden and
    hidden-hidden weights of each layer -- confirm against the serialized
    layout of the packed params.
    """
    return [
        weight_value.param.__getstate__()[0][4][slot].__getstate__()[0][0]
        for weight_value in mod._all_weight_values  # type: ignore[union-attr]
        for slot in (0, 1)
    ]

def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    """Return the weight of a conv-like module.

    * ``nn.Conv1d/2d/3d`` instances: the detached ``mod.weight``.
    * ``nni.ConvReLU1d/2d/3d`` instances: the detached ``mod[0].weight``.
    * anything else: element 0 of ``mod._weight_bias()``.
    """
    plain_conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    fused_conv_relu_types = (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)

    if isinstance(mod, plain_conv_types):
        return mod.weight.detach()
    if isinstance(mod, fused_conv_relu_types):
        return mod[0].weight.detach()
    # Every other module is expected to expose _weight_bias().
    return mod._weight_bias()[0]  # type: ignore[operator]

def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    """Return the weight of a linear-like module.

    * ``nn.Linear`` instances: the detached ``mod.weight``.
    * ``nni.LinearReLU`` instances: the detached ``mod[0].weight``.
    * anything else: element 0 of ``mod._weight_bias()``.
    """
    if isinstance(mod, nn.Linear):
        source = mod
    elif isinstance(mod, nni.LinearReLU):
        source = mod[0]
    else:
        # Every other module is expected to expose _weight_bias().
        return mod._weight_bias()[0]  # type: ignore[operator]
    return source.weight.detach()

def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
    """Return the input-hidden and hidden-hidden weights of an LSTM module.

    Args:
        mod: an ``nn.LSTM`` or an ``nnqd.LSTM`` instance.

    Returns:
        For ``nn.LSTM``: the detached entries of ``mod._flat_weights`` whose
        name in ``mod._flat_weights_names`` contains ``'weight_ih_l'`` or
        ``'weight_hh_l'``, in that order. For ``nnqd.LSTM``: two values per
        entry of ``mod._all_weight_values``, read from the nested
        ``__getstate__`` of the entry's ``param``.

    Raises:
        AssertionError: if ``mod`` is neither of the supported LSTM types.
    """
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        return [
            mod._flat_weights[idx].detach()
            for idx, param_name in enumerate(mod._flat_weights_names)
            if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name
        ]

    # Report the type of the offending module. The message must not reference
    # `res`, which is not bound yet at this point.
    assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
    res = []
    for weight_value in mod._all_weight_values:
        # Fetch the serialized state once per entry; slots 0 and 1 of its
        # fifth field each carry one value to extract.
        packed_pair = weight_value.param.__getstate__()[0][4]
        res.append(packed_pair[0].__getstate__()[0][0])
        res.append(packed_pair[1].__getstate__()[0][0])
    return res

def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Return the detached weight feeding a functional conv call.

    Takes ``node.args[1]``, resolves it through
    ``return_first_non_observer_node`` (to account for any observers between
    the weight and the conv), requires the result to be a ``get_attr`` node,
    and looks its target up on ``gm``.
    """
    second_arg = node.args[1]
    assert isinstance(second_arg, Node)

    # Walk back past any observers to the node that produces the weight.
    source_node = return_first_non_observer_node(second_arg, gm)
    assert isinstance(source_node, Node)
    assert source_node.op == 'get_attr'

    return getattr_from_fqn(gm, source_node.target).detach()  # type: ignore[arg-type]

def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Return the weight held by the state object of a quantized conv call.

    ``node.args[1]`` must be a ``get_attr`` node; the object it names on
    ``gm`` is fetched and its ``weight()`` method is called.
    """
    # the qconv state is the second positional argument
    state_node = node.args[1]
    assert isinstance(state_node, Node)
    assert state_node.op == 'get_attr'
    state_obj = getattr_from_fqn(gm, state_node.target)  # type: ignore[arg-type]
    return state_obj.weight()

def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Return the detached weight feeding a functional linear call.

    The shape of the graph behind ``node.args[1]`` selects the lookup:

    * ``call_module``  (weight -> obs -> linear): the weight is the
      ``get_attr`` node that is the module call's first argument.
    * ``call_method``  (weight -> to(torch.float16) -> dequantize -> linear):
      the weight is two hops back, and the result is cast with the dtype
      passed to the ``to`` call.
    * otherwise the argument itself must be the ``get_attr`` node.
    """
    second_arg = node.args[1]
    assert isinstance(second_arg, Node)
    arg_kind = second_arg.op

    if arg_kind == 'call_module':
        # weight -> obs -> linear
        source_node = second_arg.args[0]
        assert isinstance(source_node, Node)
        assert source_node.op == 'get_attr'
        observed_weight = getattr_from_fqn(gm, source_node.target)  # type: ignore[arg-type]
        return observed_weight.detach()

    if arg_kind == 'call_method':
        # weight -> to(torch.float16) -> dequantize -> linear
        cast_node = second_arg.args[0]
        assert isinstance(cast_node, Node)
        # remember the dtype given to `to`, so the result can be cast to it
        target_dtype = cast_node.args[1]
        source_node = cast_node.args[0]
        assert isinstance(source_node, Node)
        assert source_node.op == 'get_attr'
        raw_weight = getattr_from_fqn(gm, source_node.target)  # type: ignore[arg-type]
        return raw_weight.detach().to(target_dtype)

    # The weight attribute is passed to linear directly.
    assert arg_kind == 'get_attr'
    direct_weight = getattr_from_fqn(gm, second_arg.target)  # type: ignore[arg-type]
    return direct_weight.detach()

def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    """Return the weight stored in the packed params of a quantized linear call.

    ``node.args[1]`` must be a ``get_attr`` node. The object it names on ``gm``
    is serialized with ``__getstate__()``, which is unpacked as
    ``((weight, bias), name)``; only the weight is returned.
    """
    # the packed weight is the second positional argument
    packed_node = node.args[1]
    assert isinstance(packed_node, Node)
    assert packed_node.op == 'get_attr'
    packed_params = getattr_from_fqn(gm, packed_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    state = packed_params.__getstate__()
    (weight, _bias), _name = state
    return weight

def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
    """Build the default table of weight-extraction functions.

    Returns:
        A fresh two-level dict. The outer key is the node op
        (``'call_module'`` or ``'call_function'``); the inner key is the module
        type or function; the value is the function that extracts the weight
        for it.
    """
    # --- call_module targets -------------------------------------------
    conv_modules: Dict[Callable, Callable] = {
        nn.Conv1d: mod_weight_detach,
        nn.Conv2d: mod_weight_detach,
        nn.Conv3d: mod_weight_detach,
        nni.ConvReLU1d: mod_0_weight_detach,
        nni.ConvReLU2d: mod_0_weight_detach,
        nni.ConvReLU3d: mod_0_weight_detach,
        nnq.Conv1d: mod_weight_bias_0,
        nniqat.ConvBn1d: mod_weight_detach,
        nniqat.ConvBnReLU1d: mod_weight_detach,
        nniq.ConvReLU1d: mod_weight_bias_0,
        nnq.Conv2d: mod_weight_bias_0,
        nnqat.Conv2d: mod_weight_detach,
        nniqat.ConvBn2d: mod_weight_detach,
        nniqat.ConvBnReLU2d: mod_weight_detach,
        nniqat.ConvReLU2d: mod_weight_detach,
        nniq.ConvReLU2d: mod_weight_bias_0,
        nnq.Conv3d: mod_weight_bias_0,
        nnqat.Conv3d: mod_weight_detach,
        nniqat.ConvBn3d: mod_weight_detach,
        nniqat.ConvBnReLU3d: mod_weight_detach,
        nniqat.ConvReLU3d: mod_weight_detach,
        nniq.ConvReLU3d: mod_weight_bias_0,
    }
    linear_modules: Dict[Callable, Callable] = {
        nn.Linear: mod_weight_detach,
        nnq.Linear: mod_weight_bias_0,
        nni.LinearReLU: mod_0_weight_detach,
        nniq.LinearReLU: mod_weight_bias_0,
        nnqat.Linear: mod_weight_detach,
        nnqd.Linear: mod_weight_bias_0,
        nniqat.LinearReLU: mod_weight_detach,
        nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
    }
    lstm_modules: Dict[Callable, Callable] = {
        nn.LSTM: get_lstm_weight,
        nnqd.LSTM: get_qlstm_weight,
    }

    # --- call_function targets -----------------------------------------
    conv_functions: Dict[Callable, Callable] = {
        F.conv1d: get_conv_fun_weight,
        F.conv2d: get_conv_fun_weight,
        F.conv3d: get_conv_fun_weight,
        toq.conv1d: get_qconv_fun_weight,
        toq.conv2d: get_qconv_fun_weight,
        toq.conv3d: get_qconv_fun_weight,
        toq.conv1d_relu: get_qconv_fun_weight,
        toq.conv2d_relu: get_qconv_fun_weight,
        toq.conv3d_relu: get_qconv_fun_weight,
    }
    linear_functions: Dict[Callable, Callable] = {
        F.linear: get_linear_fun_weight,
        toq.linear: get_qlinear_fun_weight,
        toq.linear_relu: get_qlinear_fun_weight,
    }

    # Merge per-family tables; insertion order is conv, linear, then LSTM.
    return {
        'call_module': {**conv_modules, **linear_modules, **lstm_modules},
        'call_function': {**conv_functions, **linear_functions},
    }

def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> Optional[NSSingleResultType]:
    """Extract the weight used by ``node`` and wrap it in a result record.

    Args:
        node: the graph node to inspect.
        gm: the graph module that owns ``node``.
        op_to_type_to_weight_extraction_fn: optional override of the lookup
            table; when ``None`` the table from
            ``get_op_to_type_to_weight_extraction_fn()`` is used.

    Returns:
        A result dict for a ``call_function`` node whose target equals a key of
        the ``'call_function'`` table, or for a ``call_module`` node whose
        module's exact type equals a key of the ``'call_module'`` table.
        ``None`` in every other case.
    """
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not every GraphModule carries _node_name_to_scope; only consult it when
    # it is present.
    fqn = (
        gm._node_name_to_scope[node.name][0]  # type: ignore[index]
        if hasattr(gm, '_node_name_to_scope')
        else None
    )

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for weight extraction the "previous" node is the node itself
    prev_node_type = ref_node_type

    def _as_result(weight) -> NSSingleResultType:
        # Both branches below report the extracted weight the same way.
        return {
            'type': res_type,
            'values': [weight],
            'prev_node_name': node.name,
            'prev_node_target_type': prev_node_type,
            'ref_node_name': node.name,
            'ref_node_target_type': ref_node_type,
            'index_within_arg': 0,
            'index_of_arg': 0,
            'fqn': fqn,
        }

    if node.op == 'call_function':
        candidates = op_to_type_to_weight_extraction_fn['call_function']
        for known_fn, extractor in candidates.items():
            if node.target == known_fn:
                return _as_result(extractor(node, gm))

    elif node.op == 'call_module':
        # the module instance is needed to compare types
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        candidates = op_to_type_to_weight_extraction_fn['call_module']
        for known_type, extractor in candidates.items():
            # exact type match, so subclasses do not match a parent's entry
            if type(mod) == known_type:
                return _as_result(extractor(mod))

    return None
